// SPDX-License-Identifier: LGPL-3.0-or-later
#include <fcntl.h>
#include <gtest/gtest.h>
#include <sys/stat.h>
#include <sys/types.h>

#include <algorithm>
#include <cmath>
#include <fstream>
#include <vector>

#include "DeepPot.h"
#include "neighbor_list.h"
#include "test_utils.h"

// Comparison tolerance used by the tests below. Any earlier definition of
// EPSILON (presumably from test_utils.h -- TODO confirm) is discarded and
// replaced with 1e-7 when VALUETYPE is double and 1e-1 for any other type.
// The macro names VALUETYPE directly, so it only expands correctly where a
// type called VALUETYPE is in scope.
// A tolerance of 1e-10 does not pass; it is unclear whether that indicates a
// bug.
#undef EPSILON
#define EPSILON (std::is_same<VALUETYPE, double>::value ? 1e-7 : 1e-1)

// Typed fixture that loads the model file deeppot_dpa.pth and holds one
// six-atom configuration with a non-empty 9-element box, together with the
// per-atom reference energies, forces and virials the tests compare against.
// SetUp() derives the atom count and the total energy / total virial from the
// per-atom reference tables.
template <class VALUETYPE>
class TestInferDeepPotDpaPt : public ::testing::Test {
 protected:
  // Inputs: 3 coordinates per atom, one type index per atom, 9 box entries.
  std::vector<VALUETYPE> coord = {12.83, 2.56, 2.18, 12.09, 2.87, 2.74,
                                  00.25, 3.32, 1.68, 3.36,  3.00, 1.81,
                                  3.51,  2.51, 2.60, 4.27,  3.22, 1.56};
  std::vector<int> atype = {0, 1, 1, 0, 1, 1};
  std::vector<VALUETYPE> box = {13., 0., 0., 0., 13., 0., 0., 0., 13.};
  // Generated by the following Python code:
  // import numpy as np
  // from deepmd.infer import DeepPot
  // coord = np.array([
  //     12.83, 2.56, 2.18, 12.09, 2.87, 2.74,
  //     00.25, 3.32, 1.68, 3.36,  3.00, 1.81,
  //     3.51,  2.51, 2.60, 4.27,  3.22, 1.56
  // ]).reshape(1, -1)
  // atype = np.array([0, 1, 1, 0, 1, 1])
  // box = np.array([13., 0., 0., 0., 13., 0., 0., 0., 13.]).reshape(1, -1)
  // dp = DeepPot("deeppot_dpa.pth")
  // e, f, v, ae, av = dp.eval(coord, box, atype, atomic=True)
  // np.set_printoptions(precision=16)
  // print(f"{e.ravel()=} {v.ravel()=} {f.ravel()=}")
  // print(f"{ae.ravel()=} {av.ravel()=}")

  // Reference per-atom energies (one value per atom).
  std::vector<VALUETYPE> expected_e = {
      -94.37720733019096, -187.43155959873033, -187.37830241580824,
      -94.34880710985752, -187.38869830422271, -187.33919952642458};
  // Reference forces (3 values per atom).
  std::vector<VALUETYPE> expected_f = {
      5.402355596838843,  -1.263284191331685, -0.697693239979719,
      -1.025144852453706, 0.6554396369933394, 0.8817286288078215,
      0.4364579972147229, 1.2150079148857598, -0.6778076371985796,
      -6.939243547937094, 0.1571084862688049, -0.9017435514431825,
      0.3597967524845581, -1.328808718007412, 2.0974306454214653,
      1.7657780538526762, 0.5645368711911929, -0.7019148456078053};
  // Reference per-atom virials (9 values per atom).
  std::vector<VALUETYPE> expected_v = {
      9.5175137906314511e-01,  -2.0801835688892991e+00, 4.6860789988973117e-01,
      -6.0178723966859824e+00, 1.2556002911926123e-01,  4.7887097832213565e-02,
      5.6216590124464116e-01,  1.7071246159044051e-01,  8.4990129293690209e-02,
      -1.2558035496847255e+00, -3.1123763096053136e-02, -4.4100135935181761e-01,
      6.4707184007995455e-01,  1.5574441384822924e-01,  3.2409058144551339e-01,
      2.8631311270672963e+00,  -3.0375434485037031e-04, 3.9533024424985619e-01,
      3.2722174727830535e+00,  1.1867224518409690e-01,  -2.2250901443705223e-01,
      5.0337980348311300e+00,  6.0517723355290898e-01,  -5.5204995585567707e-01,
      -3.8335680797875722e+00, -2.3083403461022087e-01, 3.1281970616476651e-01,
      -1.0733902445454071e+01, -2.7634498084191517e-01, 1.5720135955951031e+00,
      -2.9262906180354680e+00, 1.0845127764896278e-01,  -1.1142053272645919e-01,
      3.6066832583682209e+00,  -1.9002351752094526e-01, 3.1875602887687587e-01,
      3.6971839777382898e-01,  -2.7352380159430506e-02, 1.0670299036230046e-01,
      1.8155828042674422e+00,  4.9170982983933986e-01,  -6.7166291183351579e-01,
      -2.9003369690467395e+00, -7.6647630459927585e-01, 1.0566933380800889e+00,
      -4.8620953903555858e-01, 4.0440213825136057e-01,  -6.5227187264812003e-01,
      -4.4421997400831864e-01, 1.4811202361724179e-01,  -2.4354470120979710e-01,
      5.3346700156430571e-01,  -1.8977527286286849e-01, 3.1383559345422440e-01};
  // Values derived in SetUp() from the reference tables above.
  int natoms = 0;                         // expected_e.size()
  double expected_tot_e = 0.;             // sum of expected_e
  std::vector<VALUETYPE> expected_tot_v;  // 9 entries, summed over atoms

  deepmd::DeepPot dp;

  // Loads the model and derives natoms, expected_tot_e and expected_tot_v.
  void SetUp() override {
    dp.init("../../tests/infer/deeppot_dpa.pth");

    natoms = static_cast<int>(expected_e.size());
    // Fatal checks: if the reference tables are inconsistent, stop here rather
    // than let the loop below (and the test bodies) index past their end.
    ASSERT_EQ(natoms * 3, static_cast<int>(expected_f.size()));
    ASSERT_EQ(natoms * 9, static_cast<int>(expected_v.size()));

    expected_tot_e = 0.;
    for (const VALUETYPE atom_e : expected_e) {
      expected_tot_e += atom_e;
    }

    expected_tot_v.assign(9, VALUETYPE(0));
    for (int ii = 0; ii < natoms; ++ii) {
      for (int dd = 0; dd < 9; ++dd) {
        expected_tot_v[dd] += expected_v[ii * 9 + dd];
      }
    }
  }

  void TearDown() override {}
};

// Register the fixture as a typed test suite over the type list ValueTypes,
// which is declared outside this file (presumably in test_utils.h).
TYPED_TEST_SUITE(TestInferDeepPotDpaPt, ValueTypes);

// Calls the DeepPot::compute overload that takes no neighbor list and checks
// the total energy, the forces and the total virial against the fixture's
// reference values, each within EPSILON.
TYPED_TEST(TestInferDeepPotDpaPt, cpu_build_nlist) {
  using VALUETYPE = TypeParam;  // EPSILON expands in terms of VALUETYPE
  std::vector<VALUETYPE>& coord = this->coord;
  std::vector<int>& atype = this->atype;
  std::vector<VALUETYPE>& box = this->box;
  const std::vector<VALUETYPE>& expected_f = this->expected_f;
  const std::vector<VALUETYPE>& expected_tot_v = this->expected_tot_v;
  const int natoms = this->natoms;
  const double expected_tot_e = this->expected_tot_e;
  deepmd::DeepPot& dp = this->dp;

  double ener = 0.;
  std::vector<VALUETYPE> force, virial;
  dp.compute(ener, force, virial, coord, atype, box);

  // Fatal size checks: the loops below index these vectors unchecked, so a
  // wrong output size must abort the test instead of reading out of bounds.
  ASSERT_EQ(static_cast<int>(force.size()), natoms * 3);
  ASSERT_EQ(static_cast<int>(virial.size()), 9);

  EXPECT_LT(std::fabs(ener - expected_tot_e), EPSILON);
  for (int ii = 0; ii < natoms * 3; ++ii) {
    EXPECT_LT(std::fabs(force[ii] - expected_f[ii]), EPSILON)
        << "force index " << ii;
  }
  for (int ii = 0; ii < 3 * 3; ++ii) {
    EXPECT_LT(std::fabs(virial[ii] - expected_tot_v[ii]), EPSILON)
        << "virial index " << ii;
  }
}

// Calls the DeepPot::compute overload that takes no neighbor list and also
// returns per-atom energies and virials, then checks totals and per-atom
// outputs against the fixture's reference values, each within EPSILON.
TYPED_TEST(TestInferDeepPotDpaPt, cpu_build_nlist_atomic) {
  using VALUETYPE = TypeParam;  // EPSILON expands in terms of VALUETYPE
  std::vector<VALUETYPE>& coord = this->coord;
  std::vector<int>& atype = this->atype;
  std::vector<VALUETYPE>& box = this->box;
  const std::vector<VALUETYPE>& expected_e = this->expected_e;
  const std::vector<VALUETYPE>& expected_f = this->expected_f;
  const std::vector<VALUETYPE>& expected_v = this->expected_v;
  const std::vector<VALUETYPE>& expected_tot_v = this->expected_tot_v;
  const int natoms = this->natoms;
  const double expected_tot_e = this->expected_tot_e;
  deepmd::DeepPot& dp = this->dp;

  double ener = 0.;
  std::vector<VALUETYPE> force, virial, atom_ener, atom_vir;
  dp.compute(ener, force, virial, atom_ener, atom_vir, coord, atype, box);

  // Fatal size checks: the loops below index these vectors unchecked, so a
  // wrong output size must abort the test instead of reading out of bounds.
  ASSERT_EQ(static_cast<int>(force.size()), natoms * 3);
  ASSERT_EQ(static_cast<int>(virial.size()), 9);
  ASSERT_EQ(static_cast<int>(atom_ener.size()), natoms);
  ASSERT_EQ(static_cast<int>(atom_vir.size()), natoms * 9);

  EXPECT_LT(std::fabs(ener - expected_tot_e), EPSILON);
  for (int ii = 0; ii < natoms * 3; ++ii) {
    EXPECT_LT(std::fabs(force[ii] - expected_f[ii]), EPSILON)
        << "force index " << ii;
  }
  for (int ii = 0; ii < 3 * 3; ++ii) {
    EXPECT_LT(std::fabs(virial[ii] - expected_tot_v[ii]), EPSILON)
        << "virial index " << ii;
  }
  for (int ii = 0; ii < natoms; ++ii) {
    EXPECT_LT(std::fabs(atom_ener[ii] - expected_e[ii]), EPSILON)
        << "atom_ener index " << ii;
  }
  for (int ii = 0; ii < natoms * 9; ++ii) {
    EXPECT_LT(std::fabs(atom_vir[ii] - expected_v[ii]), EPSILON)
        << "atom_vir index " << ii;
  }
}

// Typed fixture that loads the model file deeppot_dpa.pth and holds one
// six-atom configuration with an empty box (the reference values were
// generated with box = None), together with the per-atom reference energies,
// forces and virials the tests compare against. SetUp() derives the atom count
// and the total energy / total virial from the per-atom reference tables.
template <class VALUETYPE>
class TestInferDeepPotDpaPtNopbc : public ::testing::Test {
 protected:
  // Inputs: 3 coordinates per atom, one type index per atom, no box entries.
  std::vector<VALUETYPE> coord = {12.83, 2.56, 2.18, 12.09, 2.87, 2.74,
                                  00.25, 3.32, 1.68, 3.36,  3.00, 1.81,
                                  3.51,  2.51, 2.60, 4.27,  3.22, 1.56};
  std::vector<int> atype = {0, 1, 1, 0, 1, 1};
  std::vector<VALUETYPE> box = {};
  // Generated by the following Python code:
  // import numpy as np
  // from deepmd.infer import DeepPot
  // coord = np.array([
  //     12.83, 2.56, 2.18, 12.09, 2.87, 2.74,
  //     00.25, 3.32, 1.68, 3.36,  3.00, 1.81,
  //     3.51,  2.51, 2.60, 4.27,  3.22, 1.56
  // ]).reshape(1, -1)
  // atype = np.array([0, 1, 1, 0, 1, 1])
  // box = None
  // dp = DeepPot("deeppot_dpa.pth")
  // e, f, v, ae, av = dp.eval(coord, box, atype, atomic=True)
  // np.set_printoptions(precision=16)
  // print(f"{e.ravel()=} {v.ravel()=} {f.ravel()=}")
  // print(f"{ae.ravel()=} {av.ravel()=}")

  // Reference per-atom energies (one value per atom).
  std::vector<VALUETYPE> expected_e = {
      -95.13216447995296, -188.10146505781867, -187.74742451023172,
      -94.73864717001219, -187.76956603003393, -187.76904550434332};
  // Reference forces (3 values per atom).
  std::vector<VALUETYPE> expected_f = {
      0.7486830600282869,  -0.240322915088127,  -0.3943366458127905,
      -0.1776248813665344, 0.2359143394202788,  0.4210018319063822,
      -0.2368532809002255, 0.0291156803500336,  -0.0219651427265617,
      -1.407280069394403,  0.4932116549421467,  -0.9482072853582465,
      -0.1501958909452974, -0.9720722611839484, 1.5128172910814666,
      1.2232710625781733,  0.4541535015596165,  -0.569310049090249};
  // Reference per-atom virials (9 values per atom).
  std::vector<VALUETYPE> expected_v = {
      1.4724482801774368e+00,  -1.8952544175284314e-01, -2.0502896614522359e-01,
      -2.0361724110178425e-01, 5.4221646102123211e-02,  8.7963957026666373e-02,
      -1.3233356224791937e-01, 8.3907068051133571e-02,  1.6072164570432412e-01,
      2.2913216241740741e+00,  -6.0712170533586352e-02, 1.2802395909429765e-01,
      6.9581050483420448e-03,  2.0894022035588655e-02,  4.3408316864598340e-02,
      -1.4144392402206662e-03, 3.6852652738654124e-02,  7.7149761552687490e-02,
      5.6814285976509526e-01,  -7.0738211182030164e-02, 5.4514470128648518e-02,
      -7.1339324275474125e-02, 9.8158535704203354e-03,  -8.3431069537701560e-03,
      5.4072790262097083e-02,  -8.1976736911977682e-03, 7.6505804915597275e-03,
      1.6869950835783332e-01,  2.1880432930426963e-02,  1.0308234746703970e-01,
      9.1015395953307099e-02,  7.1788910181538768e-02,  -1.4119552688428305e-01,
      -1.4977320631771729e-01, -1.0982955047012899e-01, 2.3324521962640055e-01,
      8.1569862372597679e-01,  6.2848559999917952e-02,  -4.5341405643671506e-02,
      -3.9134119664198064e-01, 4.1651372430088562e-01,  -5.8173709994663803e-01,
      6.6155672230934037e-01,  -6.4774042800560672e-01, 9.0924772156749301e-01,
      2.0503134548416586e+00,  1.9684008914564011e-01,  -3.1711040533580070e-01,
      5.2891751962511613e-01,  8.7385258358844808e-02,  -1.5487618319904839e-01,
      -7.1396830520028809e-01, -1.0977171171532918e-01, 1.9792085656111236e-01};
  // Values derived in SetUp() from the reference tables above.
  int natoms = 0;                         // expected_e.size()
  double expected_tot_e = 0.;             // sum of expected_e
  std::vector<VALUETYPE> expected_tot_v;  // 9 entries, summed over atoms

  deepmd::DeepPot dp;

  // Loads the model and derives natoms, expected_tot_e and expected_tot_v.
  void SetUp() override {
    dp.init("../../tests/infer/deeppot_dpa.pth");

    natoms = static_cast<int>(expected_e.size());
    // Fatal checks: if the reference tables are inconsistent, stop here rather
    // than let the loop below (and the test bodies) index past their end.
    ASSERT_EQ(natoms * 3, static_cast<int>(expected_f.size()));
    ASSERT_EQ(natoms * 9, static_cast<int>(expected_v.size()));

    expected_tot_e = 0.;
    for (const VALUETYPE atom_e : expected_e) {
      expected_tot_e += atom_e;
    }

    expected_tot_v.assign(9, VALUETYPE(0));
    for (int ii = 0; ii < natoms; ++ii) {
      for (int dd = 0; dd < 9; ++dd) {
        expected_tot_v[dd] += expected_v[ii * 9 + dd];
      }
    }
  }

  void TearDown() override {}
};

// Register the fixture as a typed test suite over the type list ValueTypes,
// which is declared outside this file (presumably in test_utils.h).
TYPED_TEST_SUITE(TestInferDeepPotDpaPtNopbc, ValueTypes);

// Calls the DeepPot::compute overload that takes no neighbor list, with the
// fixture's empty box, and checks the total energy, the forces and the total
// virial against the reference values, each within EPSILON.
TYPED_TEST(TestInferDeepPotDpaPtNopbc, cpu_build_nlist) {
  using VALUETYPE = TypeParam;  // EPSILON expands in terms of VALUETYPE
  std::vector<VALUETYPE>& coord = this->coord;
  std::vector<int>& atype = this->atype;
  std::vector<VALUETYPE>& box = this->box;
  const std::vector<VALUETYPE>& expected_f = this->expected_f;
  const std::vector<VALUETYPE>& expected_tot_v = this->expected_tot_v;
  const int natoms = this->natoms;
  const double expected_tot_e = this->expected_tot_e;
  deepmd::DeepPot& dp = this->dp;

  double ener = 0.;
  std::vector<VALUETYPE> force, virial;
  dp.compute(ener, force, virial, coord, atype, box);

  // Fatal size checks: the loops below index these vectors unchecked, so a
  // wrong output size must abort the test instead of reading out of bounds.
  ASSERT_EQ(static_cast<int>(force.size()), natoms * 3);
  ASSERT_EQ(static_cast<int>(virial.size()), 9);

  EXPECT_LT(std::fabs(ener - expected_tot_e), EPSILON);
  for (int ii = 0; ii < natoms * 3; ++ii) {
    EXPECT_LT(std::fabs(force[ii] - expected_f[ii]), EPSILON)
        << "force index " << ii;
  }
  for (int ii = 0; ii < 3 * 3; ++ii) {
    EXPECT_LT(std::fabs(virial[ii] - expected_tot_v[ii]), EPSILON)
        << "virial index " << ii;
  }
}

// Calls the DeepPot::compute overload that takes no neighbor list and also
// returns per-atom energies and virials, with the fixture's empty box, then
// checks totals and per-atom outputs against the reference values, each
// within EPSILON.
TYPED_TEST(TestInferDeepPotDpaPtNopbc, cpu_build_nlist_atomic) {
  using VALUETYPE = TypeParam;  // EPSILON expands in terms of VALUETYPE
  std::vector<VALUETYPE>& coord = this->coord;
  std::vector<int>& atype = this->atype;
  std::vector<VALUETYPE>& box = this->box;
  const std::vector<VALUETYPE>& expected_e = this->expected_e;
  const std::vector<VALUETYPE>& expected_f = this->expected_f;
  const std::vector<VALUETYPE>& expected_v = this->expected_v;
  const std::vector<VALUETYPE>& expected_tot_v = this->expected_tot_v;
  const int natoms = this->natoms;
  const double expected_tot_e = this->expected_tot_e;
  deepmd::DeepPot& dp = this->dp;

  double ener = 0.;
  std::vector<VALUETYPE> force, virial, atom_ener, atom_vir;
  dp.compute(ener, force, virial, atom_ener, atom_vir, coord, atype, box);

  // Fatal size checks: the loops below index these vectors unchecked, so a
  // wrong output size must abort the test instead of reading out of bounds.
  ASSERT_EQ(static_cast<int>(force.size()), natoms * 3);
  ASSERT_EQ(static_cast<int>(virial.size()), 9);
  ASSERT_EQ(static_cast<int>(atom_ener.size()), natoms);
  ASSERT_EQ(static_cast<int>(atom_vir.size()), natoms * 9);

  EXPECT_LT(std::fabs(ener - expected_tot_e), EPSILON);
  for (int ii = 0; ii < natoms * 3; ++ii) {
    EXPECT_LT(std::fabs(force[ii] - expected_f[ii]), EPSILON)
        << "force index " << ii;
  }
  for (int ii = 0; ii < 3 * 3; ++ii) {
    EXPECT_LT(std::fabs(virial[ii] - expected_tot_v[ii]), EPSILON)
        << "virial index " << ii;
  }
  for (int ii = 0; ii < natoms; ++ii) {
    EXPECT_LT(std::fabs(atom_ener[ii] - expected_e[ii]), EPSILON)
        << "atom_ener index " << ii;
  }
  for (int ii = 0; ii < natoms * 9; ++ii) {
    EXPECT_LT(std::fabs(atom_vir[ii] - expected_v[ii]), EPSILON)
        << "atom_vir index " << ii;
  }
}

// Calls the DeepPot::compute overload that takes a caller-supplied
// deepmd::InputNlist, with the fixture's empty box, and checks the total
// energy, the forces and the total virial against the reference values, each
// within EPSILON.
TYPED_TEST(TestInferDeepPotDpaPtNopbc, cpu_lmp_nlist) {
  using VALUETYPE = TypeParam;  // EPSILON expands in terms of VALUETYPE
  std::vector<VALUETYPE>& coord = this->coord;
  std::vector<int>& atype = this->atype;
  std::vector<VALUETYPE>& box = this->box;
  const std::vector<VALUETYPE>& expected_f = this->expected_f;
  const std::vector<VALUETYPE>& expected_tot_v = this->expected_tot_v;
  int& natoms = this->natoms;
  const double expected_tot_e = this->expected_tot_e;
  deepmd::DeepPot& dp = this->dp;

  double ener = 0.;
  std::vector<VALUETYPE> force, virial;

  // Full neighbor table: every atom lists the five other atoms.
  std::vector<std::vector<int>> nlist_data = {
      {1, 2, 3, 4, 5}, {0, 2, 3, 4, 5}, {0, 1, 3, 4, 5},
      {0, 1, 2, 4, 5}, {0, 1, 2, 3, 5}, {0, 1, 2, 3, 4}};
  // Backing storage referenced by inlist; must stay alive through compute().
  std::vector<int> ilist(natoms), numneigh(natoms);
  std::vector<int*> firstneigh(natoms);
  deepmd::InputNlist inlist(natoms, ilist.data(), numneigh.data(),
                            firstneigh.data());
  // Presumably fills the inlist arrays from nlist_data -- see test_utils.h.
  convert_nlist(inlist, nlist_data);
  // NOTE(review): the two literal 0 arguments look like nghost and ago --
  // confirm against DeepPot.h.
  dp.compute(ener, force, virial, coord, atype, box, 0, inlist, 0);

  // Fatal size checks: the loops below index these vectors unchecked, so a
  // wrong output size must abort the test instead of reading out of bounds.
  ASSERT_EQ(static_cast<int>(force.size()), natoms * 3);
  ASSERT_EQ(static_cast<int>(virial.size()), 9);

  EXPECT_LT(std::fabs(ener - expected_tot_e), EPSILON);
  for (int ii = 0; ii < natoms * 3; ++ii) {
    EXPECT_LT(std::fabs(force[ii] - expected_f[ii]), EPSILON)
        << "force index " << ii;
  }
  for (int ii = 0; ii < 3 * 3; ++ii) {
    EXPECT_LT(std::fabs(virial[ii] - expected_tot_v[ii]), EPSILON)
        << "virial index " << ii;
  }
}

// Calls the DeepPot::compute overload that takes a caller-supplied
// deepmd::InputNlist and also returns per-atom energies and virials, with the
// fixture's empty box, then checks totals and per-atom outputs against the
// reference values, each within EPSILON.
TYPED_TEST(TestInferDeepPotDpaPtNopbc, cpu_lmp_nlist_atomic) {
  using VALUETYPE = TypeParam;  // EPSILON expands in terms of VALUETYPE
  std::vector<VALUETYPE>& coord = this->coord;
  std::vector<int>& atype = this->atype;
  std::vector<VALUETYPE>& box = this->box;
  const std::vector<VALUETYPE>& expected_e = this->expected_e;
  const std::vector<VALUETYPE>& expected_f = this->expected_f;
  const std::vector<VALUETYPE>& expected_v = this->expected_v;
  const std::vector<VALUETYPE>& expected_tot_v = this->expected_tot_v;
  int& natoms = this->natoms;
  const double expected_tot_e = this->expected_tot_e;
  deepmd::DeepPot& dp = this->dp;

  double ener = 0.;
  std::vector<VALUETYPE> force, virial, atom_ener, atom_vir;

  // Full neighbor table: every atom lists the five other atoms.
  std::vector<std::vector<int>> nlist_data = {
      {1, 2, 3, 4, 5}, {0, 2, 3, 4, 5}, {0, 1, 3, 4, 5},
      {0, 1, 2, 4, 5}, {0, 1, 2, 3, 5}, {0, 1, 2, 3, 4}};
  // Backing storage referenced by inlist; must stay alive through compute().
  std::vector<int> ilist(natoms), numneigh(natoms);
  std::vector<int*> firstneigh(natoms);
  deepmd::InputNlist inlist(natoms, ilist.data(), numneigh.data(),
                            firstneigh.data());
  // Presumably fills the inlist arrays from nlist_data -- see test_utils.h.
  convert_nlist(inlist, nlist_data);
  // NOTE(review): the two literal 0 arguments look like nghost and ago --
  // confirm against DeepPot.h.
  dp.compute(ener, force, virial, atom_ener, atom_vir, coord, atype, box, 0,
             inlist, 0);

  // Fatal size checks: the loops below index these vectors unchecked, so a
  // wrong output size must abort the test instead of reading out of bounds.
  ASSERT_EQ(static_cast<int>(force.size()), natoms * 3);
  ASSERT_EQ(static_cast<int>(virial.size()), 9);
  ASSERT_EQ(static_cast<int>(atom_ener.size()), natoms);
  ASSERT_EQ(static_cast<int>(atom_vir.size()), natoms * 9);

  EXPECT_LT(std::fabs(ener - expected_tot_e), EPSILON);
  for (int ii = 0; ii < natoms * 3; ++ii) {
    EXPECT_LT(std::fabs(force[ii] - expected_f[ii]), EPSILON)
        << "force index " << ii;
  }
  for (int ii = 0; ii < 3 * 3; ++ii) {
    EXPECT_LT(std::fabs(virial[ii] - expected_tot_v[ii]), EPSILON)
        << "virial index " << ii;
  }
  for (int ii = 0; ii < natoms; ++ii) {
    EXPECT_LT(std::fabs(atom_ener[ii] - expected_e[ii]), EPSILON)
        << "atom_ener index " << ii;
  }
  for (int ii = 0; ii < natoms * 9; ++ii) {
    EXPECT_LT(std::fabs(atom_vir[ii] - expected_v[ii]), EPSILON)
        << "atom_vir index " << ii;
  }
}
